package expressions

import (
	"context"
	"errors"
	"net"
	"strings"
	"testing"

	"github.com/TecharoHQ/anubis/internal/dns"
	"github.com/TecharoHQ/anubis/lib/store/memory"
	"github.com/google/cel-go/common/types"
	"github.com/google/cel-go/common/types/ref"
)

// newTestDNS builds a *dns.Dns for tests, backed by a DNS cache that stores its
// entries in an in-memory store. forwardTTL and reverseTTL are passed straight
// through to dns.NewDNSCache.
func newTestDNS(forwardTTL int, reverseTTL int) *dns.Dns {
	ctx := context.Background()
	return dns.New(ctx, dns.NewDNSCache(forwardTTL, reverseTTL, memory.New(ctx)))
}

// TestBotEnvironment exercises the custom CEL functions registered by
// BotEnvironment: missingHeader, segments, regexSafe and the DNS helpers
// (reverseDNS, lookupHost, verifyFCrDNS, arpaReverseIP).
func TestBotEnvironment(t *testing.T) {
	dnsObj := newTestDNS(300, 300)
	env, err := BotEnvironment(dnsObj)
	if err != nil {
		t.Fatalf("failed to create bot environment: %v", err)
	}

	t.Run("missingHeader", func(t *testing.T) {
		tests := []struct {
			headers     map[string]string
			name        string
			expression  string
			description string
			expected    types.Bool
		}{
			{
				name:       "missing-header",
				expression: `missingHeader(headers, "Missing-Header")`,
				headers: map[string]string{
					"User-Agent":   "test-agent",
					"Content-Type": "application/json",
				},
				expected:    types.Bool(true),
				description: "should return true when header is missing",
			},
			{
				name:       "existing-header",
				expression: `missingHeader(headers, "User-Agent")`,
				headers: map[string]string{
					"User-Agent":   "test-agent",
					"Content-Type": "application/json",
				},
				expected:    types.Bool(false),
				description: "should return false when header exists",
			},
			{
				name:       "case-sensitive",
				expression: `missingHeader(headers, "user-agent")`,
				headers: map[string]string{
					"User-Agent": "test-agent",
				},
				expected:    types.Bool(true),
				description: "should be case-sensitive (user-agent != User-Agent)",
			},
			{
				name:        "empty-headers",
				expression:  `missingHeader(headers, "Any-Header")`,
				headers:     map[string]string{},
				expected:    types.Bool(true),
				description: "should return true for any header when map is empty",
			},
			{
				name:       "real-world-sec-ch-ua",
				expression: `missingHeader(headers, "Sec-Ch-Ua")`,
				headers: map[string]string{
					"User-Agent": "curl/7.68.0",
					"Accept":     "*/*",
					"Host":       "example.com",
				},
				expected:    types.Bool(true),
				description: "should detect missing browser-specific headers from bots",
			},
			{
				name:       "browser-with-sec-ch-ua",
				expression: `missingHeader(headers, "Sec-Ch-Ua")`,
				headers: map[string]string{
					"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
					"Sec-Ch-Ua":  `"Chrome"; v="91", "Not A Brand"; v="99"`,
					"Accept":     "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
				},
				expected:    types.Bool(false),
				description: "should return false when browser sends Sec-Ch-Ua header",
			},
		}

		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				prog, err := Compile(env, tt.expression)
				if err != nil {
					t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
				}

				result, _, err := prog.Eval(map[string]any{
					"headers": tt.headers,
				})
				if err != nil {
					t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
				}

				if result != tt.expected {
					t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
				}
			})
		}

		t.Run("function-compilation", func(t *testing.T) {
			src := `missingHeader(headers, "Test-Header")`
			_, err := Compile(env, src)
			if err != nil {
				t.Fatalf("failed to compile missingHeader expression: %v", err)
			}
		})
	})

	t.Run("segments", func(t *testing.T) {
		for _, tt := range []struct {
			name        string
			description string
			expression  string
			path        string
			expected    types.Bool
		}{
			{
				name:        "simple",
				description: "/ should have one path segment",
				expression:  `size(segments(path)) == 1`,
				path:        "/",
				expected:    types.Bool(true),
			},
			{
				name:        "two segments without trailing slash",
				description: "/user/foo should have two segments",
				expression:  `size(segments(path)) == 2`,
				path:        "/user/foo",
				expected:    types.Bool(true),
			},
			{
				name:        "at least two segments",
				description: "/foo/bar/ should have at least two path segments",
				expression:  `size(segments(path)) >= 2`,
				path:        "/foo/bar/",
				expected:    types.Bool(true),
			},
			{
				name:        "at most two segments",
				description: "/foo/bar/ does not have less than two path segments",
				expression:  `size(segments(path)) < 2`,
				path:        "/foo/bar/",
				expected:    types.Bool(false),
			},
		} {
			t.Run(tt.name, func(t *testing.T) {
				prog, err := Compile(env, tt.expression)
				if err != nil {
					t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
				}

				result, _, err := prog.Eval(map[string]any{
					"path": tt.path,
				})
				if err != nil {
					t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
				}

				if result != tt.expected {
					t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
				}
			})
		}

		t.Run("invalid", func(t *testing.T) {
			for _, tt := range []struct {
				env             any
				name            string
				description     string
				expression      string
				wantFailCompile bool
				wantFailEval    bool
			}{
				{
					name:        "segments of headers",
					description: "headers are not a path list",
					expression:  `segments(headers)`,
					env: map[string]any{
						"headers": map[string]string{
							"foo": "bar",
						},
					},
					wantFailCompile: true,
				},
				{
					name:        "invalid path type",
					description: "a path should be a string",
					expression:  `size(segments(path)) != 0`,
					env: map[string]any{
						"path": 4,
					},
					wantFailEval: true,
				},
				{
					name:        "invalid path",
					description: "a path should start with a leading slash",
					expression:  `size(segments(path)) != 0`,
					env: map[string]any{
						"path": "foo",
					},
					wantFailEval: true,
				},
			} {
				t.Run(tt.name, func(t *testing.T) {
					prog, err := Compile(env, tt.expression)

					// A case that expects a compile failure must actually fail
					// to compile; succeeding here is itself a test failure.
					if tt.wantFailCompile {
						if err == nil {
							t.Fatalf("%s: expected compilation of %q to fail but it succeeded", tt.description, tt.expression)
						}
						t.Log(err)
						return
					}
					if err != nil {
						t.Fatalf("%s: failed to compile expression %q: %v", tt.description, tt.expression, err)
					}

					_, _, err = prog.Eval(tt.env)
					switch {
					case tt.wantFailEval && err == nil:
						t.Fatalf("%s: wanted an error but got none", tt.description)
					case !tt.wantFailEval && err != nil:
						t.Fatalf("%s: unexpected evaluation error: %v", tt.description, err)
					}

					if err != nil {
						t.Log(err)
					}
				})
			}
		})

		t.Run("function-compilation", func(t *testing.T) {
			src := `size(segments(path)) <= 2`
			_, err := Compile(env, src)
			if err != nil {
				t.Fatalf("failed to compile segments expression: %v", err)
			}
		})
	})

	t.Run("regexSafe", func(t *testing.T) {
		tests := []struct {
			name        string
			expression  string
			expected    types.String
			description string
		}{
			{
				name:        "complex-test",
				expression:  `regexSafe("^(test1|test2|)[a-z]+$")`,
				expected:    types.String("\\^\\(test1\\|test2\\|\\)\\[a\\-z\\]\\+\\$"),
				description: "should escape all reserved regex characters",
			},
			{
				name:        "backslash-test",
				expression:  `regexSafe("use \\\\ for special characters escaping\t, one/\"\\\"/for/cel and one/for/regex")`,
				expected:    types.String("use \\\\\\\\ for special characters escaping\t, one/\"\\\\\"/for/cel and one/for/regex"),
				description: "should escape double-backslashes as double-double-backslashes and ignore cel escaping and forward slashes",
			},
		}

		for _, tt := range tests {
			t.Run(tt.name, func(t *testing.T) {
				prog, err := Compile(env, tt.expression)
				if err != nil {
					t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
				}

				result, _, err := prog.Eval(map[string]any{})
				if err != nil {
					t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
				}

				if result != tt.expected {
					t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
				}
			})
		}

		t.Run("function-compilation", func(t *testing.T) {
			src := `regexSafe(".*")`
			_, err := Compile(env, src)
			if err != nil {
				t.Fatalf("failed to compile regexSafe expression: %v", err)
			}
		})
	})

	t.Run("dnsFunctions", func(t *testing.T) {
		// dns.DNSLookupAddr / dns.DNSLookupHost are package-level function
		// variables that the subtests below replace with mocks; restore the
		// originals once all DNS subtests have finished.
		originalDNSLookupAddr := dns.DNSLookupAddr
		originalDNSLookupHost := dns.DNSLookupHost
		t.Cleanup(func() {
			dns.DNSLookupAddr = originalDNSLookupAddr
			dns.DNSLookupHost = originalDNSLookupHost
		})

		// NOTE(review): every subtest below shares dnsObj, whose cache was
		// created with 300s TTLs in newTestDNS. A result cached by an earlier
		// subtest may therefore be served instead of the current mock. The
		// existing cases appear not to depend on this, but confirm the cache
		// behaviour before adding cases that reuse the same addresses or hosts.

		t.Run("reverseDNS", func(t *testing.T) {
			tests := []struct {
				name        string
				addr        string
				mockReturn  []string
				mockError   error
				expression  string
				expected    ref.Val
				description string
			}{
				{
					name:        "success",
					addr:        "8.8.8.8",
					mockReturn:  []string{"dns.google."},
					expression:  `reverseDNS("8.8.8.8")`,
					expected:    types.NewStringList(types.DefaultTypeAdapter, []string{"dns.google"}),
					description: "should return domain names for an IP",
				},
				{
					name:        "not-found",
					addr:        "127.0.0.1",
					mockReturn:  []string{},
					mockError:   &net.DNSError{IsNotFound: true},
					expression:  `reverseDNS("127.0.0.1")`,
					expected:    types.NewStringList(types.DefaultTypeAdapter, []string{}),
					description: "should return an empty list when not found",
				},
				{
					name:        "error",
					addr:        "error-addr",
					mockError:   errors.New("some dns error"),
					expression:  `reverseDNS("error-addr")`,
					expected:    types.NewStringList(types.DefaultTypeAdapter, []string{}),
					description: "should return empty list on error",
				},
			}

			for _, tt := range tests {
				t.Run(tt.name, func(t *testing.T) {
					dns.DNSLookupAddr = func(addr string) ([]string, error) {
						if addr == tt.addr {
							return tt.mockReturn, tt.mockError
						}
						return nil, errors.New("unexpected address for reverse lookup")
					}

					prog, err := Compile(env, tt.expression)
					if err != nil {
						t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
					}

					result, _, err := prog.Eval(map[string]any{})
					if err != nil {
						t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
					}
					if result.Equal(tt.expected) != types.True {
						t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
					}
				})
			}
		})

		t.Run("lookupHost", func(t *testing.T) {
			tests := []struct {
				name        string
				host        string
				mockReturn  []string
				mockError   error
				expression  string
				expected    ref.Val
				description string
			}{
				{
					name:        "success",
					host:        "dns.google",
					mockReturn:  []string{"8.8.8.8", "8.8.4.4"},
					expression:  `lookupHost("dns.google")`,
					expected:    types.NewStringList(types.DefaultTypeAdapter, []string{"8.8.8.8", "8.8.4.4"}),
					description: "should return IPs for a domain name",
				},
				{
					name:        "not-found",
					host:        "nonexistent.domain.example.com",
					mockReturn:  []string{},
					mockError:   &net.DNSError{IsNotFound: true},
					expression:  `lookupHost("nonexistent.domain.example.com")`,
					expected:    types.NewStringList(types.DefaultTypeAdapter, []string{}),
					description: "should return an empty list when not found",
				},
				{
					name:        "error",
					host:        "error-host",
					mockError:   errors.New("some dns error"),
					expression:  `lookupHost("error-host")`,
					expected:    types.NewStringList(types.DefaultTypeAdapter, []string{}),
					description: "should return empty list on error",
				},
			}

			for _, tt := range tests {
				t.Run(tt.name, func(t *testing.T) {
					dns.DNSLookupHost = func(host string) ([]string, error) {
						if host == tt.host {
							return tt.mockReturn, tt.mockError
						}
						return nil, errors.New("unexpected host for forward lookup")
					}

					prog, err := Compile(env, tt.expression)
					if err != nil {
						t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
					}

					result, _, err := prog.Eval(map[string]any{})
					if err != nil {
						t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
					}
					if result.Equal(tt.expected) != types.True {
						t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
					}
				})
			}
		})

		t.Run("verifyFCrDNS", func(t *testing.T) {
			tests := []struct {
				name              string
				addr              string
				reverseMockReturn []string
				reverseMockError  error
				forwardMockReturn map[string][]string // name -> ips
				forwardMockError  map[string]error
				expression        string
				expected          types.Bool
				description       string
			}{
				{
					name:              "success",
					addr:              "8.8.8.8",
					reverseMockReturn: []string{"dns.google."},
					forwardMockReturn: map[string][]string{"dns.google": {"8.8.8.8", "8.8.4.4"}},
					expression:        `verifyFCrDNS("8.8.8.8")`,
					expected:          types.Bool(true),
					description:       "should return true for valid FCrDNS",
				},
				{
					name:              "failure",
					addr:              "1.2.3.4",
					reverseMockReturn: []string{"spoofed.example.com."},
					forwardMockReturn: map[string][]string{"spoofed.example.com": {"5.6.7.8"}},
					expression:        `verifyFCrDNS("1.2.3.4")`,
					expected:          types.Bool(false),
					description:       "should return false for invalid FCrDNS",
				},
				{
					name:             "reverse-lookup-fails",
					addr:             "1.1.1.1",
					reverseMockError: errors.New("reverse lookup failed"),
					expression:       `verifyFCrDNS("1.1.1.1")`,
					expected:         types.Bool(false),
					description:      "should return false if reverse lookup fails",
				},
				{
					name:              "success-with-pattern",
					addr:              "8.8.8.8",
					reverseMockReturn: []string{"dns.google."},
					forwardMockReturn: map[string][]string{"dns.google": {"8.8.8.8"}},
					expression:        `verifyFCrDNS("8.8.8.8", "dns.google")`,
					expected:          types.Bool(true),
					description:       "should return true for valid FCrDNS with matching pattern",
				},
				{
					name:              "failure-with-pattern",
					addr:              "8.8.8.8",
					reverseMockReturn: []string{"dns.google."},
					forwardMockReturn: map[string][]string{"dns.google": {"8.8.8.8"}},
					expression:        `verifyFCrDNS("8.8.8.8", "wrong.pattern")`,
					expected:          types.Bool(false),
					description:       "should return false for FCrDNS with non-matching pattern",
				},
			}

			for _, tt := range tests {
				t.Run(tt.name, func(t *testing.T) {
					dns.DNSLookupAddr = func(addr string) ([]string, error) {
						if addr == tt.addr {
							return tt.reverseMockReturn, tt.reverseMockError
						}
						return nil, errors.New("unexpected address for reverse lookup")
					}
					dns.DNSLookupHost = func(host string) ([]string, error) {
						// The mock tables are keyed without the trailing root dot.
						host = strings.TrimSuffix(host, ".")
						if ips, ok := tt.forwardMockReturn[host]; ok {
							return ips, nil
						}
						if err, ok := tt.forwardMockError[host]; ok {
							return nil, err
						}
						return nil, &net.DNSError{IsNotFound: true}
					}

					prog, err := Compile(env, tt.expression)
					if err != nil {
						t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
					}

					result, _, err := prog.Eval(map[string]any{})
					if err != nil {
						t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
					}
					if result.Equal(tt.expected) != types.True {
						t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
					}
				})
			}
		})

		t.Run("arpaReverseIP", func(t *testing.T) {
			tests := []struct {
				name        string
				expression  string
				expected    types.String
				description string
				evalError   bool
			}{
				{
					name:        "ipv4",
					expression:  `arpaReverseIP("1.2.3.4")`,
					expected:    types.String("4.3.2.1"),
					description: "should correctly reverse an IPv4 address",
				},
				{
					name:        "ipv6",
					expression:  `arpaReverseIP("2001:db8::1")`,
					expected:    types.String("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2"),
					description: "should correctly reverse an IPv6 address",
				},
				{
					name:        "ipv6-full",
					expression:  `arpaReverseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")`,
					expected:    types.String("4.3.3.7.0.7.3.0.e.2.a.8.0.0.0.0.0.0.0.0.3.a.5.8.8.b.d.0.1.0.0.2"),
					description: "should correctly reverse a fully expanded IPv6 address",
				},
				{
					name:        "ipv6-loopback",
					expression:  `arpaReverseIP("::1")`,
					expected:    types.String("1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0"),
					description: "should correctly reverse the IPv6 loopback address",
				},
				{
					name:        "invalid-ip",
					expression:  `arpaReverseIP("not-an-ip")`,
					evalError:   true,
					description: "should error on an invalid IP",
				},
			}

			for _, tt := range tests {
				t.Run(tt.name, func(t *testing.T) {
					prog, err := Compile(env, tt.expression)
					if err != nil {
						t.Fatalf("failed to compile expression %q: %v", tt.expression, err)
					}

					result, _, err := prog.Eval(map[string]any{})
					if tt.evalError {
						if err == nil {
							t.Errorf("%s: expected an evaluation error, but got none", tt.description)
						}
						return
					}
					if err != nil {
						t.Fatalf("failed to evaluate expression %q: %v", tt.expression, err)
					}
					if result.Equal(tt.expected) != types.True {
						t.Errorf("%s: expected %v, got %v", tt.description, tt.expected, result)
					}
				})
			}
		})
	})
}

// TestThresholdEnvironment checks that the threshold CEL environment exposes
// the `weight` variable and does not expose bot-only helpers such as
// missingHeader.
func TestThresholdEnvironment(t *testing.T) {
	env, err := ThresholdEnvironment()
	if err != nil {
		t.Fatalf("failed to create threshold environment: %v", err)
	}

	cases := []struct {
		variables     map[string]any
		name          string
		expression    string
		description   string
		expected      types.Bool
		shouldCompile bool
	}{
		{
			name:          "weight-variable-available",
			expression:    `weight > 100`,
			variables:     map[string]any{"weight": 150},
			expected:      types.Bool(true),
			description:   "should support weight variable in expressions",
			shouldCompile: true,
		},
		{
			name:          "weight-variable-false-case",
			expression:    `weight > 100`,
			variables:     map[string]any{"weight": 50},
			expected:      types.Bool(false),
			description:   "should correctly evaluate weight comparisons",
			shouldCompile: true,
		},
		{
			name:          "missingHeader-not-available",
			expression:    `missingHeader(headers, "Test")`,
			variables:     map[string]any{},
			expected:      types.Bool(false), // ignored: this case must not compile
			description:   "should not have missingHeader function available",
			shouldCompile: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			prog, compileErr := Compile(env, tc.expression)

			switch {
			case !tc.shouldCompile && compileErr == nil:
				t.Fatalf("%s: expected compilation to fail but it succeeded", tc.description)
			case !tc.shouldCompile:
				return // compilation failed, which is what this case wants
			case compileErr != nil:
				t.Fatalf("failed to compile expression %q: %v", tc.expression, compileErr)
			}

			got, _, evalErr := prog.Eval(tc.variables)
			if evalErr != nil {
				t.Fatalf("failed to evaluate expression %q: %v", tc.expression, evalErr)
			}

			if got != tc.expected {
				t.Errorf("%s: expected %v, got %v", tc.description, tc.expected, got)
			}
		})
	}
}

// TestNewEnvironment verifies that the base environment returned by New
// provides randInt and the CEL string extension functions (size, contains,
// startsWith).
func TestNewEnvironment(t *testing.T) {
	env, err := New()
	if err != nil {
		t.Fatalf("failed to create new environment: %v", err)
	}

	cases := []struct {
		name          string
		expression    string
		variables     map[string]any
		expectBool    *bool // nil means: only check that the expression compiles
		description   string
		shouldCompile bool
	}{
		{
			name:          "randInt-function-compilation",
			expression:    `randInt(10)`,
			variables:     map[string]any{},
			expectBool:    nil, // result is random; compilation is all we check
			description:   "should compile randInt function",
			shouldCompile: true,
		},
		{
			name:          "randInt-range-validation",
			expression:    `randInt(10) >= 0 && randInt(10) < 10`,
			variables:     map[string]any{},
			expectBool:    boolPtr(true),
			description:   "should return values in correct range",
			shouldCompile: true,
		},
		{
			name:          "strings-extension-size",
			expression:    `"hello".size() == 5`,
			variables:     map[string]any{},
			expectBool:    boolPtr(true),
			description:   "should support string extension functions",
			shouldCompile: true,
		},
		{
			name:          "strings-extension-contains",
			expression:    `"hello world".contains("world")`,
			variables:     map[string]any{},
			expectBool:    boolPtr(true),
			description:   "should support string contains function",
			shouldCompile: true,
		},
		{
			name:          "strings-extension-startsWith",
			expression:    `"hello world".startsWith("hello")`,
			variables:     map[string]any{},
			expectBool:    boolPtr(true),
			description:   "should support string startsWith function",
			shouldCompile: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			prog, compileErr := Compile(env, tc.expression)

			if !tc.shouldCompile {
				if compileErr == nil {
					t.Fatalf("%s: expected compilation to fail but it succeeded", tc.description)
				}
				return // compilation failed, as this case requires
			}
			if compileErr != nil {
				t.Fatalf("failed to compile expression %q: %v", tc.expression, compileErr)
			}

			want := tc.expectBool
			if want == nil {
				return // compile-only case
			}

			got, _, evalErr := prog.Eval(tc.variables)
			if evalErr != nil {
				t.Fatalf("failed to evaluate expression %q: %v", tc.expression, evalErr)
			}

			if expected := types.Bool(*want); got != expected {
				t.Errorf("%s: expected %v, got %v", tc.description, *want, got)
			}
		})
	}
}

// boolPtr returns a pointer to a freshly allocated bool holding b. The table
// tests use it to mark cases that expect a boolean result.
func boolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}
